#!/usr/bin/env python

"""
This lets you ssh to a group of servers and control them as if they were one.
Each command you enter is sent to each host in parallel. The response of each
host is collected and printed. In normal synchronous mode Hive will wait for
each host to return the shell command line prompt. The shell prompt is used to
sync output.

Usage Examples:
1. You can use host name in config file

    xrun host1 host2 host3

2. You can use regex pattern

    xrun host[1-3]

3. When using shell metacharacters, put a backslash before them

    xrun host\*
"""

import sys, os, re, optparse, traceback, types, time, getpass
import pexpect, pxssh
import readline, atexit

from parse_ini import parse_config_file

CMD_HELP="""Hive commands are preceded by a colon : (just think of vi).
the name can use python regex pattern.
:ls 
    list all available hosts

:target name1 name2 name3 ...
    set list of hosts to target commands

:target all
    reset list of hosts to target all hosts in the hive. 

:to name command
    send a command line to the named host. This is similar to :target, but
    sends only one command and does not change the list of targets for future
    commands.

:sync
    set mode to wait for shell prompts after commands are run. This is the
    default. When Hive first logs into a host it sets a special shell prompt
    pattern that it can later look for to synchronize output of the hosts. If
    you 'su' to another user then it can upset the synchronization. If you need
    to run something like 'su' then use the following pattern:

    CMD (? for help) > :async
    CMD (? for help) > sudo su - root
    CMD (? for help) > :prompt
    CMD (? for help) > :sync

:async
    set mode to not expect command line prompts (see :sync). Afterwards
    commands are send to target hosts, but their responses are not read back
    until :sync is run. This is useful to run before commands that will not
    return with the special shell prompt pattern that Hive uses to synchronize.

:refresh
    refresh the display. This shows the last few lines of output from all hosts.
    This is similar to resync, but does not expect the promt. This is useful
    for seeing what hosts are doing during long running commands.

:resync
    This is similar to :sync, but it does not change the mode. It looks for the
    prompt and thus consumes all input from all targetted hosts.

:prompt
    force each host to reset command line prompt to the special pattern used to
    synchronize all the hosts. This is useful if you 'su' to a different user
    where Hive would not know the prompt to match.

:send my text
    This will send the 'my text' wihtout a line feed to the targetted hosts.
    This output of the hosts is not automatically synchronized.

:control X
    This will send the given control character to the targetted hosts.
    For example, ":control c" will send ASCII 3.

:exit
    This will exit the hive shell.

"""

def login ():
    def mask_password(password):
        ''' protect the password '''
        length = len(password)
        if length < 8:
            return '*' * length
        else:
            return password[:3] + '*' * (length-6) + password[-3:]


    # I have to keep a separate list of host names because Python dicts are not ordered.
    # I want to keep the same order as in the CLI_ARGS list.
    host_names = []
    hive_connect_info = {}
    hive = {}
    # build up the list of connection information (hostname, username, password, port)
    for hostname, hostinfo in XTOOLS_CONFIG.items():
        match = False
        for pat in CLI_ARGS:
            if re.search(pat, hostname):
                match = True
                break

        if match:
            ip = hostinfo['host']
            name = "%s(%s)" % (hostname, ip)
            host_names.append(name)
            hive_connect_info[name] = (ip, hostinfo['username'], hostinfo['password'], hostinfo['port'])

    # build up the list of hive connections using the connection information.
    for hostname in host_names:
        hostip, username, password, port = hive_connect_info[hostname]
        port = int(port)
        print 'connect to %s@%s#%d using password:%s' % (username, hostip, port, mask_password(password)),
        try:
            if OPTIONS.verbose:
                fout = file("log_"+hostname, "w")
            hive[hostname] = pxssh.pxssh()
            hive[hostname].login(hostip, username, password, port=port)
            if OPTIONS.verbose:
                hive[hostname].logfile = fout
            print '- OK'
            disconnected = False
        except Exception, e:
            print '- ERROR',
            print str(e)
            print 'Skipping', hostname
            hive[hostname] = None
            host_names.remove(hostname)
             
    return host_names, hive


def dangerous_command(cmd):
    """Return True when cmd should be confirmed by the user before it is sent.

    A command is considered dangerous when:
      * it starts with "rm" and carries a recursive option: a short option
        group containing "r" or "R" (-r, -rf, -fR, ...) or --recursive;
      * it contains "kill" or "killall" followed by whitespace anywhere in
        the line (this also covers pkill).

    This is a heuristic text match on the command line, not a shell parser.
    """
    if re.match(r'rm\s?', cmd):
        # -R and --recursive are equivalent to -r and must be caught as well.
        if re.search(r'\s+(?:-\w*[rR]|--recursive\b)', cmd):
            return True

    # "killall name" has no whitespace right after "kill", so match it explicitly.
    if re.search(r'kill(?:all)?\s+', cmd):
        return True
    return False


def deny_command(cmd):
    """Return True when cmd must never be sent to the hosts.

    Denied commands are:
      * "tail" in follow mode (-f, -F or --follow);
      * "rm" in interactive mode (-i, -I or --interactive);
      * the programs top, man and info;
      * reboot and shutdown.

    Program names are matched as whole words at the start of the line, so
    commands that merely begin with the same letters (for example "mandb")
    are not denied.
    """
    if re.match(r'tail\s?', cmd):
        if re.search(r'\s+(?:-\w*[fF]|--follow\b)', cmd):
            return True

    if re.match(r'rm\s?', cmd):
        if re.search(r'\s+(?:-\w*[iI]|--interactive\b)', cmd):
            return True

    # The name must be followed by whitespace or the end of the line.
    if re.match(r'(?:top|man|info|reboot|shutdown)(?:\s|$)', cmd):
        return True
    return False


def confirm_command(cmd):
    if dangerous_command(cmd):
        print 'DANGEROUS:', cmd
        confirm = raw_input('continue (yes/no) ? ')
        if re.match(r'(?i)(y|yes)$', confirm):
            return True
        return False

    if deny_command(cmd):
        print 'DENIED:', cmd
        return False

    return True


def main ():
    host_names, hive = login()
    synchronous_mode = True
    target_hostnames = host_names[:]
    if len(target_hostnames) == 0:
        print 'there are no host'
        return
    print 'targetting hosts:', ' '.join(target_hostnames)
    while True:
        cmd = raw_input('CMD (? for help) > ')
        cmd = cmd.strip()
        if cmd=='?' or cmd==':help' or cmd==':h':
            print CMD_HELP
            continue
        elif cmd==':refresh':
            refresh (hive, target_hostnames, timeout=0.5)
            for hostname in target_hostnames:
                if hive[hostname] is None:
                    print '/============================================================================='
                    print '| ' + hostname + ' is DEAD'
                    print '\\-----------------------------------------------------------------------------'
                else:
                    print '/============================================================================='
                    print '| ' + hostname
                    print '\\-----------------------------------------------------------------------------'
                    print hive[hostname].before
            print '=============================================================================='
            continue
        elif cmd==':resync':
            resync (hive, target_hostnames, timeout=0.5)
            for hostname in target_hostnames:
                if hive[hostname] is None:
                    print '/============================================================================='
                    print '| ' + hostname + ' is DEAD'
                    print '\\-----------------------------------------------------------------------------'
                else:
                    print '/============================================================================='
                    print '| ' + hostname
                    print '\\-----------------------------------------------------------------------------'
                    print hive[hostname].before
            print '=============================================================================='
            continue
        elif cmd==':sync':
            synchronous_mode = True
            resync (hive, target_hostnames, timeout=0.5)
            continue
        elif cmd==':async':
            synchronous_mode = False
            continue
        elif cmd==':prompt':
            for hostname in target_hostnames:
                try:
                    if hive[hostname] is not None:
                        hive[hostname].set_unique_prompt()
                except Exception, e:
                    print "Had trouble communicating with %s, so removing it from the target list." % hostname
                    print str(e)
                    hive[hostname] = None
            continue
        elif cmd[:5] == ':send':
            cmd, txt = cmd.split(None,1)

            if not confirm_command(txt):
                continue

            for hostname in target_hostnames:
                try:
                    if hive[hostname] is not None:
                        hive[hostname].send(txt)
                except Exception, e:
                    print "Had trouble communicating with %s, so removing it from the target list." % hostname
                    print str(e)
                    hive[hostname] = None
            continue
        elif cmd[:3] == ':to':
            try:
                cmd, pat, txt = cmd.split(None,2)
            except Exception, e:
                print str(e)
                continue
            hostname = None
            for host in host_names:
                if re.search(pat, host):
                    hostname = host
                    break
            if hostname is None:
                print "please input ':ls' command to list all hosts"
                continue
            if hive[hostname] is None:
                print '/============================================================================='
                print '| ' + hostname + ' is DEAD'
                print '\\-----------------------------------------------------------------------------'
                continue

            if not confirm_command(txt):
                continue

            try:
                hive[hostname].sendline (txt)
                hive[hostname].prompt(timeout=2)
                print '/============================================================================='
                print '| ' + hostname
                print '\\-----------------------------------------------------------------------------'
                print hive[hostname].before
            except Exception, e:
                print "Had trouble communicating with %s, so removing it from the target list." % hostname
                print str(e)
                hive[hostname] = None
            continue
        elif cmd[:7] == ':expect':
            cmd, pattern = cmd.split(None,1)
            print 'looking for', pattern
            try:
                for hostname in target_hostnames:
                    if hive[hostname] is not None:
                        hive[hostname].expect(pattern)
                        print hive[hostname].before
            except Exception, e:
                print "Had trouble communicating with %s, so removing it from the target list." % hostname
                print str(e)
                hive[hostname] = None
            continue
        elif cmd[:7] == ':target':
            target_hostnames = cmd.split()[1:]
            if len(target_hostnames) == 0 or target_hostnames[0] == 'all':
                target_hostnames = host_names[:]
            else:
                select_hosts = []
                for hostname in host_names:
                    match = False
                    for pat in target_hostnames:
                        if re.search(pat, hostname.split('(',1)[0]):
                            match = True
                            break
                    if match:
                       select_hosts.append(hostname) 
                target_hostnames = select_hosts
            print 'targetting hosts:\n', '\n'.join(target_hostnames)
            continue
        elif cmd == ':ls' or cmd == ':list':
            print '%-16s %s' % ('name', 'host')
            print '-' * 32
            for hostname in host_names:
                name, host = re.match(r'([^(]+)\((.+)\)', hostname).groups()
                print '%-16s %s' % (name, host)        
            continue
        elif cmd == ':exit' or cmd == ':q' or cmd == ':quit':
            break
        elif cmd[:8] == ':control' or cmd[:5] == ':ctrl' :
            cmd, c = cmd.split(None,1)
            if ord(c)-96 < 0 or ord(c)-96 > 255:
                print '/============================================================================='
                print '| Invalid character. Must be [a-zA-Z], @, [, ], \\, ^, _, or ?'
                print '\\-----------------------------------------------------------------------------'
                continue
            for hostname in target_hostnames:
                try:
                    if hive[hostname] is not None:
                        hive[hostname].sendcontrol(c)
                except Exception, e:
                    print "Had trouble communicating with %s, so removing it from the target list." % hostname
                    print str(e)
                    hive[hostname] = None
            continue
        elif cmd == ':esc':
            for hostname in target_hostnames:
                if hive[hostname] is not None:
                    hive[hostname].send(chr(27))
            continue
        #
        # Run the command on all targets in parallel
        #
        if not confirm_command(cmd):
            continue

        for hostname in target_hostnames:
            try:
                if hive[hostname] is not None:
                    hive[hostname].sendline (cmd)
            except Exception, e:
                print "Had trouble communicating with %s, so removing it from the target list." % hostname
                print str(e)
                hive[hostname] = None

        #
        # print the response for each targeted host.
        #
        if synchronous_mode:
            for hostname in target_hostnames:
                try:
                    if hive[hostname] is None:
                        print '/============================================================================='
                        print '| ' + hostname + ' is DEAD'
                        print '\\-----------------------------------------------------------------------------'
                    else:
                        hive[hostname].prompt(timeout=2)
                        print '/============================================================================='
                        print '| ' + hostname
                        print '\\-----------------------------------------------------------------------------'
                        print hive[hostname].before
                except pexpect.EOF, e:
                    print "Had trouble communicating with %s, so removing it from the target list." % hostname
                    del hive[hostname]
                    if len(hive) == 0:
                        print 'there are no hosts, now exit'
                        sys.exit(0)
                except Exception, e:
                    print "Had trouble communicating with %s, so removing it from the target list." % hostname
                    print str(e)
                    hive[hostname] = None
            print '=============================================================================='
    
def refresh (hive, hive_names, timeout=0.5):

    """Wait up to `timeout` seconds on each named host to collect its output.

    For every host, expect() is called with only pexpect.TIMEOUT and
    pexpect.EOF as patterns, so the call returns when the timeout expires or
    the connection ends. Callers read the result from the session's `before`
    attribute afterwards.

    Hosts that are marked dead (None) or missing from `hive` are skipped.
    The hosts are polled one after another, so the total delay grows with the
    number of hosts.
    """

    # TODO This is ideal for threading.
    for hostname in hive_names:
        session = hive.get(hostname)
        if session is None:
            # dead host: there is no session to read from
            continue
        session.expect([pexpect.TIMEOUT,pexpect.EOF],timeout=timeout)

def resync (hive, hive_names, timeout=2, max_attempts=5):

    """Wait for the shell prompt on each named host to bring them to the same state.

    For each host, prompt() is called repeatedly until it stops matching or
    `max_attempts` calls have been made. This is a best effort to consume all
    pending input when a host has printed more than one prompt. The timeout is
    kept low so that hosts already at the prompt do not slow things down too
    much, but every live host still costs at least one `timeout`: 10 machines
    with a 2 second timeout take AT LEAST 20 seconds.

    Hosts that are marked dead (None) or missing from `hive` are skipped.
    """

    # TODO This is ideal for threading.
    for hostname in hive_names:
        session = hive.get(hostname)
        if session is None:
            # dead host: there is no session to synchronize
            continue
        for attempts in range(max_attempts):
            if not session.prompt(timeout=timeout):
                break


if __name__ == '__main__':
    try:
        start_time = time.time()
        PARSER = optparse.OptionParser(formatter=optparse.TitledHelpFormatter(), 
                usage=globals()['__doc__'], 
                version='1.0 2009-07-13 colinli',
                conflict_handler="resolve")
        PARSER.add_option ('-v', '--verbose', action='store_true', default=False, help='verbose output')
        (OPTIONS, CLI_ARGS) = PARSER.parse_args()
        if len(CLI_ARGS) < 1:
            PARSER.error ('missing argument')
        if OPTIONS.verbose: print time.asctime()

        XTOOLS_CONFIG = parse_config_file()

        main()

        if OPTIONS.verbose: print time.asctime()
        if OPTIONS.verbose: print 'TOTAL TIME IN MINUTES:',
        if OPTIONS.verbose: print (time.time() - start_time) / 60.0
        sys.exit(0)
    except KeyboardInterrupt, e: # Ctrl-C
        sys.exit(1)
    except SystemExit, e: # sys.exit()
        raise e
    except Exception, e:
        print 'ERROR, UNEXPECTED EXCEPTION'
        print str(e)
        traceback.print_exc()
        os._exit(1)

# vi:ts=4:sw=4:expandtab:ft=python:
